package com.deep.hystrix.utils;

import com.netflix.hystrix.strategy.concurrency.HystrixConcurrencyStrategy;
import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.Field;
import java.util.concurrent.Callable;

/**
 * 获取当前线程的threadlocal变量，传给callable
 * callable在执行时，用传入的调用者的线程的threadloca，替换自己的，然后执行任务。
 * （callable执行前后，前后需要保存worker线程本身的线程局部变量以供恢复）
 */
@Slf4j
public class MyHystrixConcurrencyStrategy extends HystrixConcurrencyStrategy {

    @Override
    public <T> Callable<T> wrapCallable(Callable<T> callable) {
        /**
         * 获取当前线程的threadlocalmap
         */
        Object currentThreadlocalMap = getCurrentThreadlocalMap();

        Callable<T> finalCallable = new Callable<T>() {
            private Object callerThreadlocalMap = currentThreadlocalMap;

            private Callable<T> targetCallable = callable;

            @Override
            public T call() throws Exception {
                /**
                 * 将工作线程的原有线程变量保存起来
                 */
                Object oldThreadlocalMapOfWorkThread = getCurrentThreadlocalMap();
                /**
                 * 将本线程的线程变量，设置为caller的线程变量
                 */
                setCurrentThreadlocalMap(callerThreadlocalMap);

                try {
                    return targetCallable.call();
                } finally {
                    setCurrentThreadlocalMap(oldThreadlocalMapOfWorkThread);
                    log.info("restore work thread's threadlocal");
                }

            }
        };

        return finalCallable;
    }

    private Object getCurrentThreadlocalMap() {
        Thread thread = Thread.currentThread();
        try {
            Field field = Thread.class.getDeclaredField("threadLocals");
            field.setAccessible(true);
            Object o = field.get(thread);
            return o;
        } catch (NoSuchFieldException | IllegalAccessException e) {
            log.error("{}", e);
        }
        return null;
    }

    private void setCurrentThreadlocalMap(Object newThreadLocalMap) {
        Thread thread = Thread.currentThread();
        try {
            Field field = Thread.class.getDeclaredField("threadLocals");
            field.setAccessible(true);
            field.set(thread, newThreadLocalMap);

        } catch (NoSuchFieldException | IllegalAccessException e) {
            log.error("{}", e);
        }
    }

}
